diff --git a/app.go b/app.go index b80a7e38..5e3d5a24 100644 --- a/app.go +++ b/app.go @@ -53,7 +53,7 @@ type App struct { // Route stack divided by HTTP methods stack [][]*Route // Amount of registered routes - routes int + routesCount int // Ctx pool pool sync.Pool // Fasthttp server @@ -307,7 +307,12 @@ func (app *App) Use(args ...interface{}) Router { // Get registers a route for GET methods that requests a representation // of the specified resource. Requests using GET should only retrieve data. func (app *App) Get(path string, handlers ...Handler) Router { - return app.Add(MethodGet, path, handlers...) + route := app.register(MethodGet, path, handlers...) + // Add HEAD route + headRoute := route + app.addRoute(MethodHead, &headRoute) + + return app } // Head registers a route for HEAD methods that asks for a response identical @@ -372,7 +377,7 @@ func (app *App) Static(prefix, root string, config ...Static) Router { // All ... func (app *App) All(path string, handlers ...Handler) Router { for _, method := range intMethod { - _ = app.Add(method, path, handlers...) + app.Add(method, path, handlers...) } return app } @@ -406,26 +411,21 @@ func NewError(code int, message ...string) *Error { func (app *App) Routes() []*Route { routes := make([]*Route, 0) for m := range app.stack { + stackLoop: for r := range app.stack[m] { - // Ignore HEAD routes handling GET routes - if m == 1 && app.stack[m][r].Method == MethodGet { - continue - } - // Don't duplicate USE routes + + // Don't duplicate USE routesCount if app.stack[m][r].Method == methodUse { - duplicate := false for i := range routes { if routes[i].Method == methodUse && routes[i].Name == app.stack[m][r].Name { - duplicate = true - break + continue stackLoop } } - if !duplicate { - routes = append(routes, app.stack[m][r]) - } - } else { - routes = append(routes, app.stack[m][r]) + // Ignore HEAD routes handling GET routesCount + } else if m != methodInt(app.stack[m][r].Method) { + continue } + routes = append(routes, app.stack[m][r]) } } // Sort routes by stack position diff --git a/app_test.go b/app_test.go index a221a117..d1af9286 100644 --- a/app_test.go +++ b/app_test.go @@ -31,6 +31,18 @@ func testStatus200(t *testing.T, app *App, url string, method string) { utils.AssertEqual(t, 200, resp.StatusCode, "Status code") } +func checkRouteCount(t *testing.T, app *App, expectedCount int) { + realStackCount := 0 + for _, routes := range app.stack { + for range routes { + realStackCount++ + } + } + + utils.AssertEqual(t, expectedCount, app.routesCount) + utils.AssertEqual(t, expectedCount, realStackCount) +} + func Test_App_MethodNotAllowed(t *testing.T) { app := New() @@ -107,12 +119,32 @@ func Test_App_Routes(t *testing.T) { app := New() h := func(c *Ctx) {} app.Use("/", h) + app.Use("/", h) app.Get("/Get", h) app.Head("/Head", h) app.Post("/post", h) utils.AssertEqual(t, 4, len(app.Routes())) } +// go test -v -run=^$ -bench=Benchmark_App_Routes -benchmem -count=4 +func Benchmark_App_Routes(b *testing.B) { + app := New() + h := func(c *Ctx) {} + app.Use("/", h) + app.Use("/", h) + app.Get("/Get", h) + app.Head("/Head", h) + app.Post("/post", h) + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + app.Routes() + } + utils.AssertEqual(b, 4, len(app.Routes())) +} + func Test_App_ServerErrorHandler_SmallReadBuffer(t *testing.T) { expectedError := regexp.MustCompile( `error when reading request headers: small read buffer\. Increase ReadBufferSize\. Buffer size=4096, contents: "GET / HTTP/1.1\\r\\nHost: example\.com\\r\\nVery-Long-Header: -+`, @@ -320,6 +352,8 @@ func Test_App_Chaining(t *testing.T) { app.Use("/john", n, n, n, n, func(c *Ctx) { c.Status(202) }) + // check handler count for registered HEAD route + utils.AssertEqual(t, 5, len(app.stack[methodInt(MethodHead)][0].Handlers), "app.Test(req)") req := httptest.NewRequest("POST", "/john", nil) @@ -406,6 +440,87 @@ func Test_App_Methods(t *testing.T) { } +func Test_App_RegisteredRouteCount(t *testing.T) { + var dummyHandler = func(c *Ctx) {} + + app := New() + app.All("/:john?/:doe?", dummyHandler) + testStatus200(t, app, "/john/doe", MethodGet) + checkRouteCount(t, app, len(intMethod)) + + app = New() + app.Get("/:john?/:doe?", dummyHandler) + app.Head("/:john?/:doe?", dummyHandler) + testStatus200(t, app, "/john/doe", MethodGet) + checkRouteCount(t, app, 2) + + app = New() + app.Head("/:john?/:doe?", dummyHandler) + app.Get("/:john?/:doe?", dummyHandler) + testStatus200(t, app, "/john/doe", MethodGet) + checkRouteCount(t, app, 2) + + app = New() + app.Get("/:john?/:doe?", dummyHandler) + testStatus200(t, app, "/john/doe", MethodGet) + checkRouteCount(t, app, 2) + + app = New() + app.Head("/:john?/:doe?", dummyHandler) + testStatus200(t, app, "/john/doe", MethodHead) + checkRouteCount(t, app, 1) + + app = New() + app.Delete("/:john?/:doe?", dummyHandler) + testStatus200(t, app, "/john/doe", MethodDelete) + checkRouteCount(t, app, 1) + + // with use + app = New() + app.Use("/:john?/:doe?", dummyHandler) + testStatus200(t, app, "/john/doe", MethodPut) + checkRouteCount(t, app, len(intMethod)) + + // with group + app = New() + app.Group("/:john?/:doe?", dummyHandler).Put("/wtf", dummyHandler) + testStatus200(t, app, "/john/doe/wtf", MethodPut) + checkRouteCount(t, app, len(intMethod)+1) + + app = New() + app.Use("/", dummyHandler) + app.All("/bar", dummyHandler) + app.Get("/foo", dummyHandler) + app.Head("/foo", dummyHandler) + checkRouteCount(t, app, len(intMethod)*2+2) +} + +func Test_App_RoutePositions(t *testing.T) { + var dummyHandler = func(c *Ctx) {} + + app := New() + app.Use("/", dummyHandler) + app.All("/bar", dummyHandler) + app.Get("/foo", dummyHandler) + app.Head("/foo", dummyHandler) + testStatus200(t, app, "/foo", MethodGet) + + expectedPos := 1 + // check USE routes + for p := range intMethod { + utils.AssertEqual(t, expectedPos, app.stack[p][0].pos) + expectedPos++ + } + // check ALL routes + for p := range intMethod { + utils.AssertEqual(t, expectedPos, app.stack[p][1].pos) + expectedPos++ + } + // check GET and HEAD route + utils.AssertEqual(t, expectedPos, app.stack[methodInt(MethodGet)][2].pos) + utils.AssertEqual(t, expectedPos+1, app.stack[methodInt(MethodHead)][2].pos) +} + func Test_App_New(t *testing.T) { app := New() app.Get("/", func(*Ctx) { diff --git a/group.go b/group.go index 24b7335d..9727a431 100644 --- a/group.go +++ b/group.go @@ -42,7 +42,11 @@ func (grp *Group) Use(args ...interface{}) Router { // Get registers a route for GET methods that requests a representation // of the specified resource. Requests using GET should only retrieve data. func (grp *Group) Get(path string, handlers ...Handler) Router { - return grp.Add(MethodGet, path, handlers...) + route := grp.app.register(MethodGet, getGroupPath(grp.prefix, path), handlers...) + // Add head route + headRoute := route + grp.app.addRoute(MethodHead, &headRoute) + return grp } // Head registers a route for HEAD methods that asks for a response identical @@ -107,7 +111,7 @@ func (grp *Group) Static(prefix, root string, config ...Static) Router { // All ... func (grp *Group) All(path string, handlers ...Handler) Router { for _, method := range intMethod { - _ = grp.Add(method, path, handlers...) + grp.Add(method, path, handlers...) } return grp } diff --git a/router.go b/router.go index 886459ad..d93afa3b 100644 --- a/router.go +++ b/router.go @@ -142,7 +142,7 @@ func (app *App) handler(rctx *fasthttp.RequestCtx) { app.ReleaseCtx(ctx) } -func (app *App) register(method, pathRaw string, handlers ...Handler) { +func (app *App) register(method, pathRaw string, handlers ...Handler) Route { // Uppercase HTTP methods method = utils.ToUpper(method) // Check if the HTTP method is valid unless it's USE @@ -181,14 +181,9 @@ func (app *App) register(method, pathRaw string, handlers ...Handler) { var parsedRaw = parseRoute(pathRaw) var parsedPretty = parseRoute(pathPretty) - // Increment global route position - app.mutex.Lock() - app.routes++ - app.mutex.Unlock() - // Create route metadata - route := &Route{ + // Create route metadata without pointer + route := Route{ // Router booleans - pos: app.routes, use: isUse, star: isStar, root: isRoot, @@ -206,21 +201,19 @@ func (app *App) register(method, pathRaw string, handlers ...Handler) { if isUse { // Add route to all HTTP methods stack for _, m := range intMethod { - app.addRoute(m, route) + // create a route copy + r := route + app.addRoute(m, &r) } - return - } - - // Handle GET routes on HEAD requests - if method == MethodGet { - app.addRoute(MethodHead, route) + return route } // Add route to stack - app.addRoute(method, route) + app.addRoute(method, &route) + return route } -func (app *App) registerStatic(prefix, root string, config ...Static) { +func (app *App) registerStatic(prefix, root string, config ...Static) Route { // For security we want to restrict to the current work directory. if len(root) == 0 { root = "." @@ -305,22 +298,24 @@ func (app *App) registerStatic(prefix, root string, config ...Static) { // Next middleware c.Next() } - // Increment global route position - app.mutex.Lock() - app.routes++ - app.mutex.Unlock() - route := &Route{ - pos: app.routes, - use: true, - root: isRoot, - path: prefix, - Method: MethodGet, - Path: prefix, + + // Create route metadata without pointer + route := Route{ + // Router booleans + use: true, + root: isRoot, + path: prefix, + // Public data + Method: MethodGet, + Path: prefix, + Handlers: []Handler{handler}, } - route.Handlers = append(route.Handlers, handler) // Add route to stack - app.addRoute(MethodGet, route) - app.addRoute(MethodHead, route) + app.addRoute(MethodGet, &route) + // Add HEAD route + headRoute := route + app.addRoute(MethodHead, &headRoute) + return route } func (app *App) addRoute(method string, route *Route) { @@ -330,6 +325,19 @@ func (app *App) addRoute(method string, route *Route) { } // Get unique HTTP method indentifier m := methodInt(method) - // Add route to the stack - app.stack[m] = append(app.stack[m], route) + + // prevent identically route registration + l := len(app.stack[m]) + if l > 0 && app.stack[m][l-1].Path == route.Path && route.use == app.stack[m][l-1].use { + preRoute := app.stack[m][l-1] + preRoute.Handlers = append(preRoute.Handlers, route.Handlers...) + } else { + // Increment global route position + app.mutex.Lock() + app.routesCount++ + app.mutex.Unlock() + route.pos = app.routesCount + // Add route to the stack + app.stack[m] = append(app.stack[m], route) + } } diff --git a/router_test.go b/router_test.go index b1d88f04..ece8309e 100644 --- a/router_test.go +++ b/router_test.go @@ -341,6 +341,29 @@ func Benchmark_Router_Chain(b *testing.B) { } } +// go test -v ./... -run=^$ -bench=Benchmark_Router_WithCompression -benchmem -count=4 +func Benchmark_Router_WithCompression(b *testing.B) { + app := New() + handler := func(c *Ctx) { + c.Next() + } + app.Get("/", handler) + app.Get("/", handler) + app.Get("/", handler) + app.Get("/", handler) + app.Get("/", handler) + app.Get("/", handler) + + c := &fasthttp.RequestCtx{} + + c.Request.Header.SetMethod("GET") + c.URI().SetPath("/") + + for n := 0; n < b.N; n++ { + app.handler(c) + } +} + // go test -v ./... -run=^$ -bench=Benchmark_Router_Next -benchmem -count=4 func Benchmark_Router_Next(b *testing.B) { app := New()