diff --git a/.github/TEST_DATA/template-invalid.html b/.github/TEST_DATA/template-invalid.html new file mode 100644 index 00000000..ac4f6d13 --- /dev/null +++ b/.github/TEST_DATA/template-invalid.html @@ -0,0 +1 @@ +

{{.Title}

diff --git a/app_test.go b/app_test.go index dd872986..11923df4 100644 --- a/app_test.go +++ b/app_test.go @@ -5,12 +5,14 @@ package fiber import ( + "bytes" "crypto/tls" "errors" "fmt" "io" "io/ioutil" "net" + "net/http" "net/http/httptest" "reflect" "regexp" @@ -813,6 +815,23 @@ func Test_Test_Timeout(t *testing.T) { utils.AssertEqual(t, true, err != nil, "app.Test(req)") } +type errorReader int + +func (errorReader) Read([]byte) (int, error) { + return 0, errors.New("errorReader") +} + +func Test_Test_DumpError(t *testing.T) { + app := New() + app.Settings.DisableStartupMessage = true + + app.Get("/", func(_ *Ctx) {}) + + resp, err := app.Test(httptest.NewRequest("GET", "/", errorReader(0))) + utils.AssertEqual(t, true, resp == nil) + utils.AssertEqual(t, "errorReader", err.Error()) +} + func Test_App_Handler(t *testing.T) { h := New().Handler() utils.AssertEqual(t, "fasthttp.RequestHandler", reflect.TypeOf(h).String()) @@ -835,3 +854,116 @@ func Test_App_Init_Error_View(t *testing.T) { }() _ = app.Settings.Views.Render(nil, "", nil) } + +func Test_App_Stack(t *testing.T) { + app := New() + + app.Use("/path0", func(_ *Ctx) {}) + app.Get("/path1", func(_ *Ctx) {}) + app.Get("/path2", func(_ *Ctx) {}) + app.Post("/path3", func(_ *Ctx) {}) + + stack := app.Stack() + utils.AssertEqual(t, 9, len(stack)) + utils.AssertEqual(t, 3, len(stack[methodInt(MethodGet)])) + utils.AssertEqual(t, 3, len(stack[methodInt(MethodHead)])) + utils.AssertEqual(t, 2, len(stack[methodInt(MethodPost)])) + utils.AssertEqual(t, 1, len(stack[methodInt(MethodPut)])) + utils.AssertEqual(t, 1, len(stack[methodInt(MethodPatch)])) + utils.AssertEqual(t, 1, len(stack[methodInt(MethodDelete)])) + utils.AssertEqual(t, 1, len(stack[methodInt(MethodConnect)])) + utils.AssertEqual(t, 1, len(stack[methodInt(MethodOptions)])) + utils.AssertEqual(t, 1, len(stack[methodInt(MethodTrace)])) +} + +// go test -run Test_App_ReadTimeout +//func Test_App_ReadTimeout(t *testing.T) { +// app := New(&Settings{ +// ReadTimeout: time.Nanosecond, +// IdleTimeout: time.Minute, +// DisableStartupMessage: true, +// DisableKeepalive: true, +// }) +// +// app.Get("/read-timeout", func(c *Ctx) { +// c.SendString("I should not be sent") +// }) +// +// go func() { +// time.Sleep(500 * time.Millisecond) +// +// conn, err := net.Dial("tcp4", "127.0.0.1:4004") +// utils.AssertEqual(t, nil, err) +// defer conn.Close() +// +// _, err = conn.Write([]byte("HEAD /read-timeout HTTP/1.1\r\n")) +// utils.AssertEqual(t, nil, err) +// +// buf := make([]byte, 1024) +// var n int +// n, err = conn.Read(buf) +// +// utils.AssertEqual(t, nil, err) +// utils.AssertEqual(t, true, bytes.Contains(buf[:n], []byte("408 Request Timeout"))) +// +// utils.AssertEqual(t, nil, app.Shutdown()) +// }() +// +// utils.AssertEqual(t, nil, app.Listen(4004)) +//} + +// go test -run Test_App_BadRequest +func Test_App_BadRequest(t *testing.T) { + app := New(&Settings{ + DisableStartupMessage: true, + }) + + app.Get("/bad-request", func(c *Ctx) { + c.SendString("I should not be sent") + }) + + go func() { + time.Sleep(500 * time.Millisecond) + conn, err := net.Dial("tcp4", "127.0.0.1:4005") + utils.AssertEqual(t, nil, err) + defer conn.Close() + + _, err = conn.Write([]byte("BadRequest\r\n")) + utils.AssertEqual(t, nil, err) + + buf := make([]byte, 1024) + var n int + n, err = conn.Read(buf) + utils.AssertEqual(t, nil, err) + + utils.AssertEqual(t, true, bytes.Contains(buf[:n], []byte("400 Bad Request"))) + + utils.AssertEqual(t, nil, app.Shutdown()) + }() + + utils.AssertEqual(t, nil, app.Listen(4005)) +} + +// go test -run Test_App_SmallReadBuffer +func Test_App_SmallReadBuffer(t *testing.T) { + app := New(&Settings{ + ReadBufferSize: 1, + DisableStartupMessage: true, + }) + + app.Get("/small-read-buffer", func(c *Ctx) { + c.SendString("I should not be sent") + }) + + go func() { + time.Sleep(500 * time.Millisecond) + resp, err := http.Get("http://127.0.0.1:4006/small-read-buffer") + if resp != nil { + utils.AssertEqual(t, 431, resp.StatusCode) + } + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, nil, app.Shutdown()) + }() + + utils.AssertEqual(t, nil, app.Listen(4006)) +} diff --git a/ctx.go b/ctx.go index 8c2c84ff..f3f4d940 100644 --- a/ctx.go +++ b/ctx.go @@ -15,7 +15,6 @@ import ( "log" "mime/multipart" "net/http" - "os" "path/filepath" "strconv" "strings" @@ -791,15 +790,7 @@ func (ctx *Ctx) Render(name string, bind interface{}, layouts ...string) (err er } else { // Render raw template using 'name' as filepath if no engine is set var tmpl *template.Template - // Read file - f, err := os.Open(filepath.Clean(name)) - if err != nil { - return err - } - if _, err = buf.ReadFrom(f); err != nil { - return err - } - if err = f.Close(); err != nil { + if _, err = readContent(buf, name); err != nil { return err } // Parse template diff --git a/ctx_test.go b/ctx_test.go index 9ce93b20..d10225c1 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -10,6 +10,7 @@ package fiber import ( "bufio" "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -1259,13 +1260,20 @@ func Test_Ctx_SendFile(t *testing.T) { utils.AssertEqual(t, StatusNotModified, ctx.Fasthttp.Response.StatusCode()) utils.AssertEqual(t, []byte(nil), ctx.Fasthttp.Response.Body()) app.ReleaseCtx(ctx) +} - // test 404 - // ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) - // err = ctx.SendFile("./john_doe.go") - // // check expectation - // utils.AssertEqual(t, StatusNotFound, ctx.Fasthttp.Response.StatusCode()) - // app.ReleaseCtx(ctx) +// go test -race -run Test_Ctx_SendFile_404 +func Test_Ctx_SendFile_404(t *testing.T) { + t.Parallel() + app := New() + app.Get("/", func(ctx *Ctx) { + err := ctx.SendFile("./john_dow.go/") + utils.AssertEqual(t, false, err == nil) + }) + + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, StatusNotFound, resp.StatusCode) } // go test -race -run Test_Ctx_SendFile_Immutable @@ -1493,7 +1501,10 @@ func Test_Ctx_Render(t *testing.T) { utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "

Hello, World!

", string(ctx.Fasthttp.Response.Body())) - err = ctx.Render("./.github/TEST_DATA/invalid.html", nil) + err = ctx.Render("./.github/TEST_DATA/template-non-exists.html", nil) + utils.AssertEqual(t, false, err == nil) + + err = ctx.Render("./.github/TEST_DATA/template-invalid.html", nil) utils.AssertEqual(t, false, err == nil) } @@ -1526,6 +1537,7 @@ func Test_Ctx_Render_Engine(t *testing.T) { utils.AssertEqual(t, "

Hello, World!

", string(ctx.Fasthttp.Response.Body())) } +// go test -v -run=^$ -bench=Benchmark_Ctx_Render_Engine -benchmem -count=4 func Benchmark_Ctx_Render_Engine(b *testing.B) { engine := &testTemplateEngine{} err := engine.Load() @@ -1545,19 +1557,43 @@ func Benchmark_Ctx_Render_Engine(b *testing.B) { utils.AssertEqual(b, "

Hello, World!

", string(ctx.Fasthttp.Response.Body())) } +type errorTemplateEngine struct{} + +func (t errorTemplateEngine) Render(w io.Writer, name string, bind interface{}, layout ...string) error { + return errors.New("errorTemplateEngine") +} + +func (t errorTemplateEngine) Load() error { return nil } + +// go test -run Test_Ctx_Render_Engine_Error +func Test_Ctx_Render_Engine_Error(t *testing.T) { + app := New() + app.Settings.Views = errorTemplateEngine{} + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + err := ctx.Render("index.tmpl", nil) + utils.AssertEqual(t, false, err == nil) +} + // go test -run Test_Ctx_Render_Go_Template func Test_Ctx_Render_Go_Template(t *testing.T) { t.Parallel() + file, err := ioutil.TempFile(os.TempDir(), "fiber") utils.AssertEqual(t, nil, err) defer os.Remove(file.Name()) + _, err = file.Write([]byte("template")) utils.AssertEqual(t, nil, err) + err = file.Close() utils.AssertEqual(t, nil, err) + app := New() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) + err = ctx.Render(file.Name(), nil) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "template", string(ctx.Fasthttp.Response.Body())) diff --git a/utils.go b/utils.go index 903c9389..9852f834 100644 --- a/utils.go +++ b/utils.go @@ -8,7 +8,10 @@ import ( "bytes" "fmt" "hash/crc32" + "io" "net" + "os" + "path/filepath" "strings" "time" @@ -17,6 +20,19 @@ import ( fasthttp "github.com/valyala/fasthttp" ) +// readContent opens a named file and read content from it +func readContent(rf io.ReaderFrom, name string) (n int64, err error) { + // Read file + f, err := os.Open(filepath.Clean(name)) + if err != nil { + return 0, err + } + defer func() { + err = f.Close() + }() + return rf.ReadFrom(f) +} + // quoteString escape special characters in a given string func quoteString(raw string) string { bb := bytebufferpool.Get() @@ -248,7 +264,6 @@ func (a testAddr) String() string { } type testConn struct { - net.Conn r bytes.Buffer w bytes.Buffer }