diff --git a/listen.go b/listen.go index 1ba0f1c4..8d465840 100644 --- a/listen.go +++ b/listen.go @@ -36,29 +36,7 @@ func (r *Fiber) Listen(address interface{}, tls ...string) { log.Fatal("Listen: Host must be an INT port or STRING address") } // Create fasthttp server - server := &fasthttp.Server{ - Handler: r.handler, - Name: r.Server, - Concurrency: r.Engine.Concurrency, - DisableKeepalive: r.Engine.DisableKeepAlive, - ReadBufferSize: r.Engine.ReadBufferSize, - WriteBufferSize: r.Engine.WriteBufferSize, - ReadTimeout: r.Engine.ReadTimeout, - WriteTimeout: r.Engine.WriteTimeout, - IdleTimeout: r.Engine.IdleTimeout, - MaxConnsPerIP: r.Engine.MaxConnsPerIP, - MaxRequestsPerConn: r.Engine.MaxRequestsPerConn, - TCPKeepalive: r.Engine.TCPKeepalive, - TCPKeepalivePeriod: r.Engine.TCPKeepalivePeriod, - MaxRequestBodySize: r.Engine.MaxRequestBodySize, - ReduceMemoryUsage: r.Engine.ReduceMemoryUsage, - GetOnly: r.Engine.GetOnly, - DisableHeaderNamesNormalizing: r.Engine.DisableHeaderNamesNormalizing, - SleepWhenConcurrencyLimitsExceeded: r.Engine.SleepWhenConcurrencyLimitsExceeded, - NoDefaultServerHeader: r.Server == "", - NoDefaultContentType: r.Engine.NoDefaultContentType, - KeepHijackedConns: r.Engine.KeepHijackedConns, - } + server := r.setupServer() // Prefork enabled if r.Prefork && runtime.NumCPU() > 1 { @@ -137,3 +115,29 @@ func (r *Fiber) prefork(server *fasthttp.Server, host string, tls ...string) { log.Fatal("Listen-prefork: ", err) } } + +func (r *Fiber) setupServer() *fasthttp.Server { + return &fasthttp.Server{ + Handler: r.handler, + Name: r.Server, + Concurrency: r.Engine.Concurrency, + DisableKeepalive: r.Engine.DisableKeepAlive, + ReadBufferSize: r.Engine.ReadBufferSize, + WriteBufferSize: r.Engine.WriteBufferSize, + ReadTimeout: r.Engine.ReadTimeout, + WriteTimeout: r.Engine.WriteTimeout, + IdleTimeout: r.Engine.IdleTimeout, + MaxConnsPerIP: r.Engine.MaxConnsPerIP, + MaxRequestsPerConn: r.Engine.MaxRequestsPerConn, + TCPKeepalive: r.Engine.TCPKeepalive, + TCPKeepalivePeriod: r.Engine.TCPKeepalivePeriod, + MaxRequestBodySize: r.Engine.MaxRequestBodySize, + ReduceMemoryUsage: r.Engine.ReduceMemoryUsage, + GetOnly: r.Engine.GetOnly, + DisableHeaderNamesNormalizing: r.Engine.DisableHeaderNamesNormalizing, + SleepWhenConcurrencyLimitsExceeded: r.Engine.SleepWhenConcurrencyLimitsExceeded, + NoDefaultServerHeader: r.Server == "", + NoDefaultContentType: r.Engine.NoDefaultContentType, + KeepHijackedConns: r.Engine.KeepHijackedConns, + } +} diff --git a/request_test.go b/request_test.go index da442e49..84280a91 100644 --- a/request_test.go +++ b/request_test.go @@ -5,6 +5,7 @@ import ( "fmt" "mime/multipart" "net/http" + "net/http/httptest" "net/url" "strconv" "strings" @@ -14,19 +15,26 @@ import ( func Test_Accepts(t *testing.T) { app := New() app.Get("/test", func(c *Ctx) { - c.Accepts() - expect := ".xml" + expect := "" result := c.Accepts(expect) + if c.Accepts() != "" { + t.Fatalf(`Expecting %s, got %s`, expect, result) + } + expect = ".xml" + result = c.Accepts(expect) if result != expect { - t.Fatalf(`%s: Expecting %s, got %s`, t.Name(), expect, result) + t.Fatalf(`Expecting %s, got %s`, expect, result) } }) - req, _ := http.NewRequest("GET", "/test", nil) + req := httptest.NewRequest("GET", "/test", nil) req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8") - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_AcceptsCharsets(t *testing.T) { app := New() @@ -41,10 +49,13 @@ func Test_AcceptsCharsets(t *testing.T) { }) req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Accept-Charset", "utf-8, iso-8859-1;q=0.5") - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_AcceptsEncodings(t *testing.T) { app := New() @@ -58,10 +69,13 @@ func Test_AcceptsEncodings(t *testing.T) { }) req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Accept-Encoding", "deflate, gzip;q=1.0, *;q=0.5") - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_AcceptsLanguages(t *testing.T) { app := New() @@ -75,10 +89,13 @@ func Test_AcceptsLanguages(t *testing.T) { }) req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Accept-Language", "fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5") - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_BaseURL(t *testing.T) { app := New() @@ -91,10 +108,13 @@ func Test_BaseURL(t *testing.T) { } }) req, _ := http.NewRequest("GET", "http://google.com/test", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_BasicAuth(t *testing.T) { app := New() @@ -111,10 +131,13 @@ func Test_BasicAuth(t *testing.T) { }) req, _ := http.NewRequest("GET", "/test", nil) req.SetBasicAuth("john", "doe") - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Body(t *testing.T) { app := New() @@ -151,10 +174,13 @@ func Test_Body(t *testing.T) { req, _ := http.NewRequest("POST", "/test", strings.NewReader(data.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Cookies(t *testing.T) { app := New() @@ -217,10 +243,13 @@ func Test_FormValue(t *testing.T) { req.Header.Set("Content-Type", contentType) req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes()))) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Fresh(t *testing.T) { app := New() @@ -256,10 +285,13 @@ func Test_Get(t *testing.T) { req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Accept-Charset", "utf-8, iso-8859-1;q=0.5") req.Header.Set("Referer", "Cookie") - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Hostname(t *testing.T) { app := New() @@ -271,10 +303,13 @@ func Test_Hostname(t *testing.T) { } }) req, _ := http.NewRequest("GET", "http://google.com/test", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_IP(t *testing.T) { app := New() @@ -287,10 +322,13 @@ func Test_IP(t *testing.T) { } }) req, _ := http.NewRequest("GET", "http://google.com/test", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_IPs(t *testing.T) { app := New() @@ -304,10 +342,13 @@ func Test_IPs(t *testing.T) { }) req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("X-Forwarded-For", "0.0.0.0, 1.1.1.1") - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Is(t *testing.T) { app := New() @@ -321,10 +362,13 @@ func Test_Is(t *testing.T) { }) req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Content-Type", "text/html") - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Locals(t *testing.T) { app := New() @@ -340,10 +384,13 @@ func Test_Locals(t *testing.T) { } }) req, _ := http.NewRequest("GET", "/test", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Method(t *testing.T) { app := New() @@ -369,20 +416,29 @@ func Test_Method(t *testing.T) { } }) req, _ := http.NewRequest("GET", "/test", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } req, _ = http.NewRequest("POST", "/test", nil) - _, err = app.Test(req) + resp, err = app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } req, _ = http.NewRequest("PUT", "/test", nil) - _, err = app.Test(req) + resp, err = app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_MultipartForm(t *testing.T) { app := New() @@ -407,10 +463,13 @@ func Test_MultipartForm(t *testing.T) { req.Header.Set("Content-Type", contentType) req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes()))) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_OriginalURL(t *testing.T) { app := New() @@ -423,10 +482,13 @@ func Test_OriginalURL(t *testing.T) { } }) req, _ := http.NewRequest("GET", "http://google.com/test?search=demo", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Params(t *testing.T) { app := New() @@ -445,15 +507,21 @@ func Test_Params(t *testing.T) { } }) req, _ := http.NewRequest("GET", "/test/john", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } req, _ = http.NewRequest("GET", "/test2/im/a/cookie", nil) - _, err = app.Test(req) + resp, err = app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Path(t *testing.T) { app := New() @@ -465,10 +533,13 @@ func Test_Path(t *testing.T) { } }) req, _ := http.NewRequest("GET", "/test/john", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Query(t *testing.T) { app := New() @@ -485,10 +556,13 @@ func Test_Query(t *testing.T) { } }) req, _ := http.NewRequest("GET", "/test?search=john&age=20", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Range(t *testing.T) { app := New() @@ -511,10 +585,13 @@ func Test_Route(t *testing.T) { } }) req, _ := http.NewRequest("GET", "/test", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_SaveFile(t *testing.T) { // TODO @@ -529,10 +606,13 @@ func Test_Secure(t *testing.T) { } }) req, _ := http.NewRequest("GET", "/test", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_SignedCookies(t *testing.T) { app := New() @@ -540,10 +620,13 @@ func Test_SignedCookies(t *testing.T) { c.SignedCookies() }) req, _ := http.NewRequest("GET", "/test", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_Stale(t *testing.T) { app := New() @@ -566,10 +649,13 @@ func Test_Subdomains(t *testing.T) { } }) req, _ := http.NewRequest("GET", "http://john.doe.google.com/test", nil) - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } func Test_XHR(t *testing.T) { app := New() @@ -583,8 +669,11 @@ func Test_XHR(t *testing.T) { }) req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("X-Requested-With", "XMLHttpRequest") - _, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } } diff --git a/response_test.go b/response_test.go index a7daa5e9..d677455e 100644 --- a/response_test.go +++ b/response_test.go @@ -2,7 +2,6 @@ package fiber import ( "net/http" - "strings" "testing" ) @@ -13,11 +12,14 @@ func Test_Append(t *testing.T) { c.Append("X-Test", "lo", "world") }) req, _ := http.NewRequest("GET", "/test", nil) - res, err := app.Test(req) + resp, err := app.Test(req) if err != nil { t.Fatalf(`%s: %s`, t.Name(), err) } - if !strings.Contains(res, "X-Test: hel, lo, world") { + if resp.StatusCode != 200 { + t.Fatalf(`%s: StatusCode %v`, t.Name(), resp.StatusCode) + } + if resp.Header.Get("X-Test") != "hel, lo, world" { t.Fatalf(`%s: Expecting %s`, t.Name(), "X-Test: hel, lo, world") } } diff --git a/utils.go b/utils.go index 386c3a58..aa084385 100644 --- a/utils.go +++ b/utils.go @@ -8,6 +8,7 @@ package fiber import ( + "bufio" "bytes" "fmt" "io/ioutil" @@ -21,8 +22,6 @@ import ( "strings" "time" "unsafe" - - "github.com/valyala/fasthttp" ) func getParams(path string) (params []string) { @@ -77,6 +76,9 @@ func getFiles(root string) (files []string, isDir bool, err error) { } func getType(ext string) (mime string) { + if ext == "" { + return mime + } if ext[0] == '.' { ext = ext[1:] } @@ -109,108 +111,68 @@ func getBytes(s string) (b []byte) { return b } -// FakeRequest is the same as Test -func (r *Fiber) FakeRequest(req interface{}) (string, error) { - return r.Test(req) -} - -// Test creates a readWriter and calls ServeConn on local servver -func (r *Fiber) Test(req interface{}) (string, error) { - raw := "" - switch r := req.(type) { - case string: - raw = r - case *http.Request: - d, err := httputil.DumpRequest(r, true) - if err != nil { - return "", err - } - raw = getString(d) +// Test takes a http.Request and execute a fake connection to the application +// It returns a http.Response when the connection was successfull +func (r *Fiber) Test(req *http.Request) (*http.Response, error) { + // Get raw http request + reqRaw, err := httputil.DumpRequest(req, true) + if err != nil { + return nil, err } - server := &fasthttp.Server{ - Handler: r.handler, - Name: r.Server, - Concurrency: r.Engine.Concurrency, - DisableKeepalive: r.Engine.DisableKeepAlive, - ReadBufferSize: r.Engine.ReadBufferSize, - WriteBufferSize: r.Engine.WriteBufferSize, - ReadTimeout: r.Engine.ReadTimeout, - WriteTimeout: r.Engine.WriteTimeout, - IdleTimeout: r.Engine.IdleTimeout, - MaxConnsPerIP: r.Engine.MaxConnsPerIP, - MaxRequestsPerConn: r.Engine.MaxRequestsPerConn, - TCPKeepalive: r.Engine.TCPKeepalive, - TCPKeepalivePeriod: r.Engine.TCPKeepalivePeriod, - MaxRequestBodySize: r.Engine.MaxRequestBodySize, - ReduceMemoryUsage: r.Engine.ReduceMemoryUsage, - GetOnly: r.Engine.GetOnly, - DisableHeaderNamesNormalizing: r.Engine.DisableHeaderNamesNormalizing, - SleepWhenConcurrencyLimitsExceeded: r.Engine.SleepWhenConcurrencyLimitsExceeded, - NoDefaultServerHeader: r.Server == "", - NoDefaultContentType: r.Engine.NoDefaultContentType, - KeepHijackedConns: r.Engine.KeepHijackedConns, - } - rw := &readWriter{} - rw.r.WriteString(raw) - - ch := make(chan error) + // Setup a fiber server struct + server := r.setupServer() + // Create fake connection + conn := &conn{} + // Pass HTTP request to conn + conn.r.Write(reqRaw) + // Serve conn to server + channel := make(chan error) go func() { - ch <- server.ServeConn(rw) + channel <- server.ServeConn(conn) }() - + // Wait for callback select { - case err := <-ch: + case err := <-channel: if err != nil { - return "", err + return nil, err } + // Throw timeout error after 200ms case <-time.After(200 * time.Millisecond): - return "", fmt.Errorf("Timeout") + return nil, fmt.Errorf("Timeout") } - - err := server.ServeConn(rw) + // Get raw HTTP response + respRaw, err := ioutil.ReadAll(&conn.w) if err != nil { - return "", err + return nil, err } - resp, err := ioutil.ReadAll(&rw.w) + // Create buffer + reader := strings.NewReader(getString(respRaw)) + buffer := bufio.NewReader(reader) + // Convert raw HTTP response to http.Response + resp, err := http.ReadResponse(buffer, req) if err != nil { - return "", err + return nil, err } - return getString(resp), nil + // Return *http.Response + return resp, nil } -// Readwriter for test cases -type readWriter struct { +// https://golang.org/src/net/net.go#L113 +type conn struct { net.Conn r bytes.Buffer w bytes.Buffer } -func (rw *readWriter) Close() error { - return nil -} - -func (rw *readWriter) Read(b []byte) (int, error) { - return rw.r.Read(b) -} - -func (rw *readWriter) Write(b []byte) (int, error) { - return rw.w.Write(b) -} - -func (rw *readWriter) RemoteAddr() net.Addr { +func (c *conn) RemoteAddr() net.Addr { return &net.TCPAddr{ - IP: net.IPv4zero, + IP: net.IPv4(0, 0, 0, 0), } } - -func (rw *readWriter) LocalAddr() net.Addr { - return rw.RemoteAddr() -} - -func (rw *readWriter) SetReadDeadline(t time.Time) error { - return nil -} - -func (rw *readWriter) SetWriteDeadline(t time.Time) error { - return nil -} +func (c *conn) LocalAddr() net.Addr { return c.LocalAddr() } +func (c *conn) Read(b []byte) (int, error) { return c.r.Read(b) } +func (c *conn) Write(b []byte) (int, error) { return c.w.Write(b) } +func (c *conn) Close() error { return nil } +func (c *conn) SetDeadline(t time.Time) error { return nil } +func (c *conn) SetReadDeadline(t time.Time) error { return nil } +func (c *conn) SetWriteDeadline(t time.Time) error { return nil }