1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-23 21:43:48 +00:00
This commit is contained in:
Kiyon 2021-03-10 08:28:51 +08:00
commit a36899c641
34 changed files with 2529 additions and 286 deletions

2
.github/README.md vendored
View File

@ -153,7 +153,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

View File

@ -153,7 +153,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

View File

@ -150,7 +150,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

View File

@ -153,7 +153,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

View File

@ -206,7 +206,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

View File

@ -153,7 +153,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

View File

@ -156,7 +156,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

View File

@ -157,7 +157,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

View File

@ -157,7 +157,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

View File

@ -153,7 +153,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

51
.github/README_ru.md vendored
View File

@ -153,7 +153,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})
@ -207,27 +207,27 @@ func main() {
```go
func main() {
app := fiber.New()
app := fiber.New()
// Match any route
app.Use(func(c *fiber.Ctx) error {
fmt.Println("🥇 First handler")
return c.Next()
})
// Match any route
app.Use(func(c *fiber.Ctx) error {
fmt.Println("🥇 First handler")
return c.Next()
})
// Match all routes starting with /api
app.Use("/api", func(c *fiber.Ctx) error {
fmt.Println("🥈 Second handler")
return c.Next()
})
// Match all routes starting with /api
app.Use("/api", func(c *fiber.Ctx) error {
fmt.Println("🥈 Second handler")
return c.Next()
})
// GET /api/register
app.Get("/api/list", func(c *fiber.Ctx) error {
fmt.Println("🥉 Last handler")
return c.SendString("Hello, World 👋!")
})
// GET /api/register
app.Get("/api/list", func(c *fiber.Ctx) error {
fmt.Println("🥉 Last handler")
return c.SendString("Hello, World 👋!")
})
log.Fatal(app.Listen(":3000"))
log.Fatal(app.Listen(":3000"))
}
```
@ -248,8 +248,6 @@ func main() {
Ознакомьтесь с пакетом [Template](https://github.com/gofiber/template), который поддерживает множество движков для views.
```go
package main
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/template/pug"
@ -271,6 +269,7 @@ func main() {
log.Fatal(app.Listen(":3000"))
}
```
### Группировка путей в цепочки
@ -313,8 +312,6 @@ func main() {
📖 [Logger](https://docs.gofiber.io/middleware/logger)
```go
package main
import (
"log"
@ -331,6 +328,7 @@ func main() {
log.Fatal(app.Listen(":3000"))
}
```
### Cross-Origin Resource Sharing (CORS)
@ -354,6 +352,7 @@ func main() {
log.Fatal(app.Listen(":3000"))
}
```
Проверем CORS, присвоив домен в заголовок `Origin`, отличный от `localhost`:
@ -388,6 +387,7 @@ func main() {
log.Fatal(app.Listen(":3000"))
}
```
### JSON Response
@ -418,6 +418,7 @@ func main() {
log.Fatal(app.Listen(":3000"))
}
```
### WebSocket Upgrade
@ -440,7 +441,9 @@ func main() {
log.Println("read:", err)
break
}
log.Printf("recv: %s", msg)
err = c.WriteMessage(mt, msg)
if err != nil {
log.Println("write:", err)
@ -450,8 +453,9 @@ func main() {
}))
log.Fatal(app.Listen(":3000"))
// ws://localhost:3000/ws
// => ws://localhost:3000/ws
}
```
### Recover middleware
@ -475,6 +479,7 @@ func main() {
log.Fatal(app.Listen(":3000"))
}
```
</details>

View File

@ -172,7 +172,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

28
.github/README_tr.md vendored
View File

@ -73,7 +73,7 @@
</a>
</p>
<p align="center">
<b>Fiber</b>, <a href="https://golang.org/doc/">Go</a> için <b>en hızlı</b> HTTP motoru olan <a href="https://github.com/valyala/fasthttp">Fasthttp</a> üzerine inşa edilmiş, <a href="https://github.com/expressjs/express">Express</a> den ilham alan bir <b>web çatısıdır</b>. <b>Sıfır bellek ayırma</b> ve <b>performans</b> göz önünde bulundurularak <b>hızlı</b> geliştirme için işleri <b>kolaylaştırmak</b> üzere tasarlandı.
<b>Fiber</b>, <a href="https://golang.org/doc/">Go</a> için <b>en hızlı</b> HTTP motoru olan <a href="https://github.com/valyala/fasthttp">Fasthttp</a> üzerine inşa edilmiş, <a href="https://github.com/expressjs/express">Express</a>den ilham alan bir <b>web çatısıdır</b>. <b>Sıfır bellek ayırma</b> ve <b>performans</b> göz önünde bulundurularak <b>hızlı</b> geliştirme için işleri <b>kolaylaştırmak</b> üzere tasarlandı.
</p>
## ⚡️ Hızlı Başlangıç
@ -96,7 +96,7 @@ func main() {
## 🤖 Performans Ölçümleri
Bu testler [TechEmpower](https://www.techempower.com/benchmarks/#section=data-r19&hw=ph&test=plaintext) ve [Go Web](https://github.com/smallnest/go-web-framework-benchmark) ile koşuldu. Bütün sonuçları görmek için lütfen [Wiki](https://docs.gofiber.io/benchmarks) sayfasını ziyaret ediniz.
Bu testler [TechEmpower](https://www.techempower.com/benchmarks/#section=data-r19&hw=ph&test=plaintext) ve [Go Web](https://github.com/smallnest/go-web-framework-benchmark) ile gerçekleştirildi. Bütün sonuçları görmek için lütfen [Wiki](https://docs.gofiber.io/benchmarks) sayfasını ziyaret ediniz.
<p float="left" align="middle">
<img src="https://raw.githubusercontent.com/gofiber/docs/master/.gitbook/assets/benchmark-pipeline.png" width="49%">
@ -105,9 +105,9 @@ Bu testler [TechEmpower](https://www.techempower.com/benchmarks/#section=data-r1
## ⚙️ Kurulum
Make sure you have Go installed ([download](https://golang.org/dl/)). Version `1.14` or higher is required.
Go'nun `1.14` sürümü ([indir](https://golang.org/dl/)) ya da daha yüksek bir sürüm gerekli.
Initialize your project by creating a folder and then running `go mod init github.com/your/repo` ([learn more](https://blog.golang.org/using-go-modules)) inside the folder. Then install Fiber with the [`go get`](https://golang.org/cmd/go/#hdr-Add_dependencies_to_current_module_and_install_them) command:
Bir klasör oluşturup klasörün içinde `go mod init github.com/your/repo` yazarak projenize başlayın ([daha fazla öğren](https://blog.golang.org/using-go-modules)). Ardından Fiberi kurmak için [`go get`](https://golang.org/cmd/go/#hdr-Add_dependencies_to_current_module_and_install_them) komutunu çalıştırın:
```bash
go get -u github.com/gofiber/fiber/v2
@ -125,12 +125,12 @@ go get -u github.com/gofiber/fiber/v2
- [Template engines](https://github.com/gofiber/template)
- [WebSocket support](https://github.com/gofiber/websocket)
- [Rate Limiter](https://docs.gofiber.io/middleware#limiter)
- Available in [15 languages](https://docs.gofiber.io/)
- Ve daha fazlası, [Fiber ı keşfet](https://docs.gofiber.io/)
- [15 dilde](https://docs.gofiber.io/) mevcut
- Ve daha fazlası, [Fiber'ı keşfet](https://docs.gofiber.io/)
## 💡 Felsefe
[Node.js](https://nodejs.org/en/about/) den [Go](https://golang.org/doc/) ya geçen yeni gopher lar kendi web uygulamalarını ve mikroservislerini yazmaya başlamadan önce dili öğrenmek ile uğraşıyorlar. Fiber, bir **web çatısı** olarak, **minimalizm** ve **UNIX yolu**nu izlemek fikri ile oluşturuldu. Böylece yeni gopher lar sıcak ve güvenilir bir hoşgeldin ile Go dünyasına giriş yapabilirler.
[Node.js](https://nodejs.org/en/about/) den [Go](https://golang.org/doc/) ya geçen yeni gopher lar kendi web uygulamalarını ve mikroservislerini yazmaya başlamadan önce dili öğrenmek ile uğraşıyorlar. Fiber, bir **web çatısı** olarak, **minimalizm** ve **UNIX yolu**nu izlemek fikri ile oluşturuldu. Böylece yeni gopherlar sıcak ve güvenilir bir hoşgeldin ile Go dünyasına giriş yapabilirler.
Fiber internet üzerinde en popüler olan Express web çatısından **esinlenmiştir**. Biz Express in **kolaylığını** ve Go nun **ham performansını** birleştirdik. Daha önce Node.js üzerinde (Express veya benzerini kullanarak) bir web uygulaması geliştirdiyseniz, pek çok metod ve prensip size **çok tanıdık** gelecektir.
@ -151,7 +151,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})
@ -478,9 +478,9 @@ func main() {
## 🧬 Internal Middleware
Here is a list of middleware that are included within the Fiber framework.
Fibera dahil edilen middlewareların bir listesi aşağıda verilmiştir.
| Middleware | Description |
| Middleware | ıklama |
| :------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [basicauth](https://github.com/gofiber/fiber/tree/master/middleware/basicauth) | Basic auth middleware provides an HTTP basic authentication. It calls the next handler for valid credentials and 401 Unauthorized for missing or invalid credentials. |
| [compress](https://github.com/gofiber/fiber/tree/master/middleware/compress) | Compression middleware for Fiber, it supports `deflate`, `gzip` and `brotli` by default. |
@ -499,9 +499,9 @@ Here is a list of middleware that are included within the Fiber framework.
## 🧬 External Middleware
List of externally hosted middleware modules and maintained by the [Fiber team](https://github.com/orgs/gofiber/people).
Harici olarak barındırılan middlewareların modüllerinin listesi [Fiber ekibi](https://github.com/orgs/gofiber/people) tarafından korunur.
| Middleware | Description |
| Middleware | ıklama |
| :------------------------------------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| [adaptor](https://github.com/gofiber/adaptor) | Converter for net/http handlers to/from Fiber request handlers, special thanks to @arsmn! |
| [helmet](https://github.com/gofiber/helmet) | Helps secure your apps by setting various HTTP headers. |
@ -514,7 +514,7 @@ List of externally hosted middleware modules and maintained by the [Fiber team](
## 🌱 Third Party Middlewares
This is a list of middlewares that are created by the Fiber community, please create a PR if you want to see yours!
Bu, Fiber topluluğu tarafından oluşturulan middlewareların bir listesidir, sizinkini görmek istiyorsanız lütfen bir PR oluşturun!
- [arsmn/fiber-casbin](https://github.com/arsmn/fiber-casbin)
- [arsmn/fiber-introspect](https://github.com/arsmn/fiber-introspect)
@ -571,7 +571,7 @@ Fiber, alan adı, gitbook, netlify, serverless yer sağlayıcısı giderleri ve
<img src="https://opencollective.com/fiber/contributors.svg?width=890&button=false" alt="Code Contributors" style="max-width:100%;">
## ⭐️ Stargazers
## ⭐️ Projeyi Yıldızlayanlar
<img src="https://starchart.cc/gofiber/fiber.svg" alt="Stargazers over time" style="max-width: 100%">

View File

@ -152,7 +152,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

View File

@ -155,7 +155,7 @@ func main() {
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c *fiber.Ctx) error {
app.Get("/:name/:age", func(c *fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})

46
app.go
View File

@ -568,7 +568,7 @@ func (app *App) Listener(ln net.Listener) error {
app.startupProcess()
// Print startup message
if !app.config.DisableStartupMessage {
app.startupMessage(ln.Addr().String(), false, "")
app.startupMessage(ln.Addr().String(), getTlsConfig(ln) != nil, "")
}
// Start listening
return app.server.Serve(ln)
@ -817,17 +817,6 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
return
}
var logo string
logo += "%s"
logo += " ┌───────────────────────────────────────────────────┐\n"
logo += " │ %s │\n"
logo += " │ %s │\n"
logo += " │ │\n"
logo += " │ Handlers %s Processes %s │\n"
logo += " │ Prefork .%s PID ....%s │\n"
logo += " └───────────────────────────────────────────────────┘"
logo += "%s"
const (
cBlack = "\u001b[90m"
// cRed = "\u001b[91m"
@ -886,12 +875,13 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
}
host, port := parseAddr(addr)
if host == "" || host == "0.0.0.0" {
host = "127.0.0.1"
if host == "" {
host = "0.0.0.0"
}
addr = "http://" + host + ":" + port
scheme := "http"
if tls {
addr = "https://" + host + ":" + port
scheme = "https"
}
isPrefork := "Disabled"
@ -904,13 +894,27 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
procs = "1"
}
mainLogo := fmt.Sprintf(logo,
cBlack,
centerValue(" Fiber v"+Version, 49),
center(addr, 49),
mainLogo := cBlack+
" ┌───────────────────────────────────────────────────┐\n"+
" │ "+centerValue(" Fiber v"+Version, 49)+" │\n"
if host == "0.0.0.0" {
mainLogo +=
" │ "+center(fmt.Sprintf("%s://127.0.0.1:%s", scheme, port), 49)+ " │\n" +
" │ "+center(fmt.Sprintf("(bound on host 0.0.0.0 and port %s)", port), 49)+ " │\n"
} else {
mainLogo +=
" │ "+center(fmt.Sprintf("%s://%s:%s", scheme, host, port), 49)+ " │\n"
}
mainLogo += fmt.Sprintf(
" │ │\n"+
" │ Handlers %s Processes %s │\n"+
" │ Prefork .%s PID ....%s │\n"+
" └───────────────────────────────────────────────────┘"+
cReset,
value(strconv.Itoa(app.handlerCount), 14), value(procs, 12),
value(isPrefork, 14), value(strconv.Itoa(os.Getpid()), 14),
cReset,
)
var childPidsLogo string

View File

@ -382,29 +382,6 @@ func Test_App_Add_Method_Test(t *testing.T) {
app.Add("JOHN", "/doe", testEmptyHandler)
}
func Test_App_Listener_TLS(t *testing.T) {
app := New()
// Create tls certificate
cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
if err != nil {
utils.AssertEqual(t, nil, err)
}
config := &tls.Config{Certificates: []tls.Certificate{cer}}
ln, err := net.Listen(NetworkTCP4, ":3078")
utils.AssertEqual(t, nil, err)
ln = tls.NewListener(ln, config)
go func() {
time.Sleep(1000 * time.Millisecond)
utils.AssertEqual(t, nil, app.Shutdown())
}()
utils.AssertEqual(t, nil, app.Listener(ln))
}
// go test -run Test_App_GETOnly
func Test_App_GETOnly(t *testing.T) {
app := New(Config{
@ -1011,6 +988,27 @@ func Test_App_Listener_Prefork(t *testing.T) {
utils.AssertEqual(t, nil, app.Listener(ln))
}
func Test_App_Listener_TLS(t *testing.T) {
// Create tls certificate
cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
if err != nil {
utils.AssertEqual(t, nil, err)
}
config := &tls.Config{Certificates: []tls.Certificate{cer}}
ln, err := tls.Listen(NetworkTCP4, ":0", config)
utils.AssertEqual(t, nil, err)
app := New()
go func() {
time.Sleep(time.Millisecond * 500)
utils.AssertEqual(t, nil, app.Shutdown())
}()
utils.AssertEqual(t, nil, app.Listener(ln))
}
// go test -v -run=^$ -bench=Benchmark_AcquireCtx -benchmem -count=4
func Benchmark_AcquireCtx(b *testing.B) {
app := New()

985
client.go Normal file
View File

@ -0,0 +1,985 @@
package fiber
import (
"bytes"
"crypto/tls"
"encoding/xml"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/gofiber/fiber/v2/utils"
"github.com/gofiber/fiber/v2/internal/encoding/json"
"github.com/valyala/fasthttp"
)
// Request represents HTTP request.
//
// It is forbidden copying Request instances. Create new instances
// and use CopyTo instead.
//
// Request instance MUST NOT be used from concurrently running goroutines.
// Copy from fasthttp
type Request = fasthttp.Request
// Response represents HTTP response.
//
// It is forbidden copying Response instances. Create new instances
// and use CopyTo instead.
//
// Response instance MUST NOT be used from concurrently running goroutines.
// Copy from fasthttp
type Response = fasthttp.Response
// Args represents query arguments.
//
// It is forbidden copying Args instances. Create new instances instead
// and use CopyTo().
//
// Args instance MUST NOT be used from concurrently running goroutines.
// Copy from fasthttp
type Args = fasthttp.Args
var defaultClient Client
// Client implements http client.
//
// It is safe calling Client methods from concurrently running goroutines.
type Client struct {
// UserAgent is used in User-Agent request header.
UserAgent string
// NoDefaultUserAgentHeader when set to true, causes the default
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool
// When set by an external client of Fiber it will use the provided implementation of a
// JSONMarshal
//
// Allowing for flexibility in using another json library for encoding
JSONEncoder utils.JSONMarshal
// When set by an external client of Fiber it will use the provided implementation of a
// JSONUnmarshal
//
// Allowing for flexibility in using another json library for decoding
JSONDecoder utils.JSONUnmarshal
}
// Get returns a agent with http method GET.
func Get(url string) *Agent { return defaultClient.Get(url) }
// Get returns a agent with http method GET.
func (c *Client) Get(url string) *Agent {
return c.createAgent(MethodGet, url)
}
// Head returns a agent with http method HEAD.
func Head(url string) *Agent { return defaultClient.Head(url) }
// Head returns a agent with http method GET.
func (c *Client) Head(url string) *Agent {
return c.createAgent(MethodHead, url)
}
// Post sends POST request to the given url.
func Post(url string) *Agent { return defaultClient.Post(url) }
// Post sends POST request to the given url.
func (c *Client) Post(url string) *Agent {
return c.createAgent(MethodPost, url)
}
// Put sends PUT request to the given url.
func Put(url string) *Agent { return defaultClient.Put(url) }
// Put sends PUT request to the given url.
func (c *Client) Put(url string) *Agent {
return c.createAgent(MethodPut, url)
}
// Patch sends PATCH request to the given url.
func Patch(url string) *Agent { return defaultClient.Patch(url) }
// Patch sends PATCH request to the given url.
func (c *Client) Patch(url string) *Agent {
return c.createAgent(MethodPatch, url)
}
// Delete sends DELETE request to the given url.
func Delete(url string) *Agent { return defaultClient.Delete(url) }
// Delete sends DELETE request to the given url.
func (c *Client) Delete(url string) *Agent {
return c.createAgent(MethodDelete, url)
}
func (c *Client) createAgent(method, url string) *Agent {
a := AcquireAgent()
a.req.Header.SetMethod(method)
a.req.SetRequestURI(url)
a.Name = c.UserAgent
a.NoDefaultUserAgentHeader = c.NoDefaultUserAgentHeader
a.jsonDecoder = c.JSONDecoder
a.jsonEncoder = c.JSONEncoder
if err := a.Parse(); err != nil {
a.errs = append(a.errs, err)
}
return a
}
// Agent is an object storing all request data for client.
// Agent instance MUST NOT be used from concurrently running goroutines.
type Agent struct {
// Name is used in User-Agent request header.
Name string
// NoDefaultUserAgentHeader when set to true, causes the default
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool
// HostClient is an embedded fasthttp HostClient
*fasthttp.HostClient
req *Request
resp *Response
dest []byte
args *Args
timeout time.Duration
errs []error
formFiles []*FormFile
debugWriter io.Writer
mw multipartWriter
jsonEncoder utils.JSONMarshal
jsonDecoder utils.JSONUnmarshal
maxRedirectsCount int
boundary string
reuse bool
parsed bool
}
// Parse initializes URI and HostClient.
func (a *Agent) Parse() error {
if a.parsed {
return nil
}
a.parsed = true
uri := a.req.URI()
isTLS := false
scheme := uri.Scheme()
if bytes.Equal(scheme, strHTTPS) {
isTLS = true
} else if !bytes.Equal(scheme, strHTTP) {
return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme)
}
name := a.Name
if name == "" && !a.NoDefaultUserAgentHeader {
name = defaultUserAgent
}
a.HostClient = &fasthttp.HostClient{
Addr: addMissingPort(string(uri.Host()), isTLS),
Name: name,
NoDefaultUserAgentHeader: a.NoDefaultUserAgentHeader,
IsTLS: isTLS,
}
return nil
}
func addMissingPort(addr string, isTLS bool) string {
n := strings.Index(addr, ":")
if n >= 0 {
return addr
}
port := 80
if isTLS {
port = 443
}
return net.JoinHostPort(addr, strconv.Itoa(port))
}
/************************** Header Setting **************************/
// Set sets the given 'key: value' header.
//
// Use Add for setting multiple header values under the same key.
func (a *Agent) Set(k, v string) *Agent {
a.req.Header.Set(k, v)
return a
}
// SetBytesK sets the given 'key: value' header.
//
// Use AddBytesK for setting multiple header values under the same key.
func (a *Agent) SetBytesK(k []byte, v string) *Agent {
a.req.Header.SetBytesK(k, v)
return a
}
// SetBytesV sets the given 'key: value' header.
//
// Use AddBytesV for setting multiple header values under the same key.
func (a *Agent) SetBytesV(k string, v []byte) *Agent {
a.req.Header.SetBytesV(k, v)
return a
}
// SetBytesKV sets the given 'key: value' header.
//
// Use AddBytesKV for setting multiple header values under the same key.
func (a *Agent) SetBytesKV(k []byte, v []byte) *Agent {
a.req.Header.SetBytesKV(k, v)
return a
}
// Add adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use Set for setting a single header for the given key.
func (a *Agent) Add(k, v string) *Agent {
a.req.Header.Add(k, v)
return a
}
// AddBytesK adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesK for setting a single header for the given key.
func (a *Agent) AddBytesK(k []byte, v string) *Agent {
a.req.Header.AddBytesK(k, v)
return a
}
// AddBytesV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesV for setting a single header for the given key.
func (a *Agent) AddBytesV(k string, v []byte) *Agent {
a.req.Header.AddBytesV(k, v)
return a
}
// AddBytesKV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesKV for setting a single header for the given key.
func (a *Agent) AddBytesKV(k []byte, v []byte) *Agent {
a.req.Header.AddBytesKV(k, v)
return a
}
// ConnectionClose sets 'Connection: close' header.
func (a *Agent) ConnectionClose() *Agent {
a.req.Header.SetConnectionClose()
return a
}
// UserAgent sets User-Agent header value.
func (a *Agent) UserAgent(userAgent string) *Agent {
a.req.Header.SetUserAgent(userAgent)
return a
}
// UserAgentBytes sets User-Agent header value.
func (a *Agent) UserAgentBytes(userAgent []byte) *Agent {
a.req.Header.SetUserAgentBytes(userAgent)
return a
}
// Cookie sets one 'key: value' cookie.
func (a *Agent) Cookie(key, value string) *Agent {
a.req.Header.SetCookie(key, value)
return a
}
// CookieBytesK sets one 'key: value' cookie.
func (a *Agent) CookieBytesK(key []byte, value string) *Agent {
a.req.Header.SetCookieBytesK(key, value)
return a
}
// CookieBytesKV sets one 'key: value' cookie.
func (a *Agent) CookieBytesKV(key, value []byte) *Agent {
a.req.Header.SetCookieBytesKV(key, value)
return a
}
// Cookies sets multiple 'key: value' cookies.
func (a *Agent) Cookies(kv ...string) *Agent {
for i := 1; i < len(kv); i += 2 {
a.req.Header.SetCookie(kv[i-1], kv[i])
}
return a
}
// CookiesBytesKV sets multiple 'key: value' cookies.
func (a *Agent) CookiesBytesKV(kv ...[]byte) *Agent {
for i := 1; i < len(kv); i += 2 {
a.req.Header.SetCookieBytesKV(kv[i-1], kv[i])
}
return a
}
// Referer sets Referer header value.
func (a *Agent) Referer(referer string) *Agent {
a.req.Header.SetReferer(referer)
return a
}
// RefererBytes sets Referer header value.
func (a *Agent) RefererBytes(referer []byte) *Agent {
a.req.Header.SetRefererBytes(referer)
return a
}
// ContentType sets Content-Type header value.
func (a *Agent) ContentType(contentType string) *Agent {
a.req.Header.SetContentType(contentType)
return a
}
// ContentTypeBytes sets Content-Type header value.
func (a *Agent) ContentTypeBytes(contentType []byte) *Agent {
a.req.Header.SetContentTypeBytes(contentType)
return a
}
/************************** End Header Setting **************************/
/************************** URI Setting **************************/
// Host sets host for the uri.
func (a *Agent) Host(host string) *Agent {
a.req.URI().SetHost(host)
return a
}
// HostBytes sets host for the URI.
func (a *Agent) HostBytes(host []byte) *Agent {
a.req.URI().SetHostBytes(host)
return a
}
// QueryString sets URI query string.
func (a *Agent) QueryString(queryString string) *Agent {
a.req.URI().SetQueryString(queryString)
return a
}
// QueryStringBytes sets URI query string.
func (a *Agent) QueryStringBytes(queryString []byte) *Agent {
a.req.URI().SetQueryStringBytes(queryString)
return a
}
// BasicAuth sets URI username and password.
func (a *Agent) BasicAuth(username, password string) *Agent {
a.req.URI().SetUsername(username)
a.req.URI().SetPassword(password)
return a
}
// BasicAuthBytes sets URI username and password.
func (a *Agent) BasicAuthBytes(username, password []byte) *Agent {
a.req.URI().SetUsernameBytes(username)
a.req.URI().SetPasswordBytes(password)
return a
}
/************************** End URI Setting **************************/
/************************** Request Setting **************************/
// BodyString sets request body.
func (a *Agent) BodyString(bodyString string) *Agent {
a.req.SetBodyString(bodyString)
return a
}
// Body sets request body.
func (a *Agent) Body(body []byte) *Agent {
a.req.SetBody(body)
return a
}
// BodyStream sets request body stream and, optionally body size.
//
// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes
// before returning io.EOF.
//
// If bodySize < 0, then bodyStream is read until io.EOF.
//
// bodyStream.Close() is called after finishing reading all body data
// if it implements io.Closer.
//
// Note that GET and HEAD requests cannot have body.
func (a *Agent) BodyStream(bodyStream io.Reader, bodySize int) *Agent {
a.req.SetBodyStream(bodyStream, bodySize)
return a
}
// JSON sends a JSON request.
func (a *Agent) JSON(v interface{}) *Agent {
if a.jsonEncoder == nil {
a.jsonEncoder = json.Marshal
}
a.req.Header.SetContentType(MIMEApplicationJSON)
if body, err := a.jsonEncoder(v); err != nil {
a.errs = append(a.errs, err)
} else {
a.req.SetBody(body)
}
return a
}
// XML sends a XML request.
func (a *Agent) XML(v interface{}) *Agent {
a.req.Header.SetContentType(MIMEApplicationXML)
if body, err := xml.Marshal(v); err != nil {
a.errs = append(a.errs, err)
} else {
a.req.SetBody(body)
}
return a
}
// Form sends form request with body if args is non-nil.
//
// It is recommended obtaining args via AcquireArgs and release it
// manually in performance-critical code.
func (a *Agent) Form(args *Args) *Agent {
a.req.Header.SetContentType(MIMEApplicationForm)
if args != nil {
a.req.SetBody(args.QueryString())
}
return a
}
// FormFile represents multipart form file
type FormFile struct {
// Fieldname is form file's field name
Fieldname string
// Name is form file's name
Name string
// Content is form file's content
Content []byte
// autoRelease indicates if returns the object
// acquired via AcquireFormFile to the pool.
autoRelease bool
}
// FileData appends files for multipart form request.
//
// It is recommended obtaining formFile via AcquireFormFile and release it
// manually in performance-critical code.
func (a *Agent) FileData(formFiles ...*FormFile) *Agent {
a.formFiles = append(a.formFiles, formFiles...)
return a
}
// SendFile reads file and appends it to multipart form request.
func (a *Agent) SendFile(filename string, fieldname ...string) *Agent {
content, err := ioutil.ReadFile(filepath.Clean(filename))
if err != nil {
a.errs = append(a.errs, err)
return a
}
ff := AcquireFormFile()
if len(fieldname) > 0 && fieldname[0] != "" {
ff.Fieldname = fieldname[0]
} else {
ff.Fieldname = "file" + strconv.Itoa(len(a.formFiles)+1)
}
ff.Name = filepath.Base(filename)
ff.Content = append(ff.Content, content...)
ff.autoRelease = true
a.formFiles = append(a.formFiles, ff)
return a
}
// SendFiles reads files and appends them to multipart form request.
//
// Examples:
// SendFile("/path/to/file1", "fieldname1", "/path/to/file2")
func (a *Agent) SendFiles(filenamesAndFieldnames ...string) *Agent {
pairs := len(filenamesAndFieldnames)
if pairs&1 == 1 {
filenamesAndFieldnames = append(filenamesAndFieldnames, "")
}
for i := 0; i < pairs; i += 2 {
a.SendFile(filenamesAndFieldnames[i], filenamesAndFieldnames[i+1])
}
return a
}
// Boundary sets boundary for multipart form request.
func (a *Agent) Boundary(boundary string) *Agent {
a.boundary = boundary
return a
}
// MultipartForm sends multipart form request with k-v and files.
//
// It is recommended obtaining args via AcquireArgs and release it
// manually in performance-critical code.
func (a *Agent) MultipartForm(args *Args) *Agent {
if a.mw == nil {
a.mw = multipart.NewWriter(a.req.BodyWriter())
}
if a.boundary != "" {
if err := a.mw.SetBoundary(a.boundary); err != nil {
a.errs = append(a.errs, err)
return a
}
}
a.req.Header.SetMultipartFormBoundary(a.mw.Boundary())
if args != nil {
args.VisitAll(func(key, value []byte) {
if err := a.mw.WriteField(getString(key), getString(value)); err != nil {
a.errs = append(a.errs, err)
}
})
}
for _, ff := range a.formFiles {
w, err := a.mw.CreateFormFile(ff.Fieldname, ff.Name)
if err != nil {
a.errs = append(a.errs, err)
continue
}
if _, err = w.Write(ff.Content); err != nil {
a.errs = append(a.errs, err)
}
}
if err := a.mw.Close(); err != nil {
a.errs = append(a.errs, err)
}
return a
}
/************************** End Request Setting **************************/
/************************** Agent Setting **************************/
// Debug mode enables logging request and response detail
func (a *Agent) Debug(w ...io.Writer) *Agent {
a.debugWriter = os.Stdout
if len(w) > 0 {
a.debugWriter = w[0]
}
return a
}
// Timeout sets request timeout duration.
func (a *Agent) Timeout(timeout time.Duration) *Agent {
a.timeout = timeout
return a
}
// Reuse enables the Agent instance to be used again after one request.
//
// If agent is reusable, then it should be released manually when it is no
// longer used.
func (a *Agent) Reuse() *Agent {
a.reuse = true
return a
}
// InsecureSkipVerify controls whether the Agent verifies the server
// certificate chain and host name.
func (a *Agent) InsecureSkipVerify() *Agent {
if a.HostClient.TLSConfig == nil {
/* #nosec G402 */
a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true}
} else {
/* #nosec G402 */
a.HostClient.TLSConfig.InsecureSkipVerify = true
}
return a
}
// TLSConfig sets tls config.
func (a *Agent) TLSConfig(config *tls.Config) *Agent {
a.HostClient.TLSConfig = config
return a
}
// MaxRedirectsCount sets max redirect count for GET and HEAD.
func (a *Agent) MaxRedirectsCount(count int) *Agent {
a.maxRedirectsCount = count
return a
}
// JSONEncoder sets custom json encoder.
func (a *Agent) JSONEncoder(jsonEncoder utils.JSONMarshal) *Agent {
a.jsonEncoder = jsonEncoder
return a
}
// JSONDecoder sets custom json decoder.
func (a *Agent) JSONDecoder(jsonDecoder utils.JSONUnmarshal) *Agent {
a.jsonDecoder = jsonDecoder
return a
}
// Request returns Agent request instance.
func (a *Agent) Request() *Request {
return a.req
}
// SetResponse sets custom response for the Agent instance.
//
// It is recommended obtaining custom response via AcquireResponse and release it
// manually in performance-critical code.
func (a *Agent) SetResponse(customResp *Response) *Agent {
a.resp = customResp
return a
}
// Dest sets custom dest.
//
// The contents of dest will be replaced by the response body, if the dest
// is too small a new slice will be allocated.
func (a *Agent) Dest(dest []byte) *Agent {
a.dest = dest
return a
}
/************************** End Agent Setting **************************/
// Bytes returns the status code, bytes body and errors of url.
func (a *Agent) Bytes() (code int, body []byte, errs []error) {
fmt.Println("[Warning] client is still in beta, API might change in the future!")
defer a.release()
if errs = append(errs, a.errs...); len(errs) > 0 {
return
}
var (
req = a.req
resp *Response
nilResp bool
)
if a.resp == nil {
resp = AcquireResponse()
nilResp = true
} else {
resp = a.resp
}
defer func() {
if a.debugWriter != nil {
printDebugInfo(req, resp, a.debugWriter)
}
if len(errs) == 0 {
code = resp.StatusCode()
}
body = append(a.dest, resp.Body()...)
if nilResp {
ReleaseResponse(resp)
}
}()
if a.timeout > 0 {
if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil {
errs = append(errs, err)
return
}
}
if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) {
if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil {
errs = append(errs, err)
return
}
}
if err := a.HostClient.Do(req, resp); err != nil {
errs = append(errs, err)
}
return
}
func printDebugInfo(req *Request, resp *Response, w io.Writer) {
msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr())
_, _ = w.Write(getBytes(msg))
_, _ = req.WriteTo(w)
_, _ = resp.WriteTo(w)
}
// String returns the status code, string body and errors of url.
func (a *Agent) String() (int, string, []error) {
code, body, errs := a.Bytes()
return code, getString(body), errs
}
// Struct returns the status code, bytes body and errors of url.
// And bytes body will be unmarshalled to given v.
func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
if a.jsonDecoder == nil {
a.jsonDecoder = json.Unmarshal
}
if code, body, errs = a.Bytes(); len(errs) > 0 {
return
}
if err := a.jsonDecoder(body, v); err != nil {
errs = append(errs, err)
}
return
}
func (a *Agent) release() {
if !a.reuse {
ReleaseAgent(a)
} else {
a.errs = a.errs[:0]
}
}
func (a *Agent) reset() {
a.HostClient = nil
a.req.Reset()
a.resp = nil
a.dest = nil
a.timeout = 0
a.args = nil
a.errs = a.errs[:0]
a.debugWriter = nil
a.mw = nil
a.reuse = false
a.parsed = false
a.maxRedirectsCount = 0
a.boundary = ""
a.Name = ""
a.NoDefaultUserAgentHeader = false
for i, ff := range a.formFiles {
if ff.autoRelease {
ReleaseFormFile(ff)
}
a.formFiles[i] = nil
}
a.formFiles = a.formFiles[:0]
}
var (
clientPool sync.Pool
agentPool sync.Pool
responsePool sync.Pool
argsPool sync.Pool
formFilePool sync.Pool
)
// AcquireClient returns an empty Client instance from client pool.
//
// The returned Client instance may be passed to ReleaseClient when it is
// no longer needed. This allows Client recycling, reduces GC pressure
// and usually improves performance.
func AcquireClient() *Client {
v := clientPool.Get()
if v == nil {
return &Client{}
}
return v.(*Client)
}
// ReleaseClient returns c acquired via AcquireClient to client pool.
//
// It is forbidden accessing req and/or its' members after returning
// it to client pool.
func ReleaseClient(c *Client) {
c.UserAgent = ""
c.NoDefaultUserAgentHeader = false
clientPool.Put(c)
}
// AcquireAgent returns an empty Agent instance from Agent pool.
//
// The returned Agent instance may be passed to ReleaseAgent when it is
// no longer needed. This allows Agent recycling, reduces GC pressure
// and usually improves performance.
func AcquireAgent() *Agent {
v := agentPool.Get()
if v == nil {
return &Agent{req: &Request{}}
}
return v.(*Agent)
}
// ReleaseAgent returns a acquired via AcquireAgent to Agent pool.
//
// It is forbidden accessing req and/or its' members after returning
// it to Agent pool.
func ReleaseAgent(a *Agent) {
a.reset()
agentPool.Put(a)
}
// AcquireResponse returns an empty Response instance from response pool.
//
// The returned Response instance may be passed to ReleaseResponse when it is
// no longer needed. This allows Response recycling, reduces GC pressure
// and usually improves performance.
// Copy from fasthttp
func AcquireResponse() *Response {
v := responsePool.Get()
if v == nil {
return &Response{}
}
return v.(*Response)
}
// ReleaseResponse return resp acquired via AcquireResponse to response pool.
//
// It is forbidden accessing resp and/or its' members after returning
// it to response pool.
// Copy from fasthttp
func ReleaseResponse(resp *Response) {
resp.Reset()
responsePool.Put(resp)
}
// AcquireArgs returns an empty Args object from the pool.
//
// The returned Args may be returned to the pool with ReleaseArgs
// when no longer needed. This allows reducing GC load.
// Copy from fasthttp
func AcquireArgs() *Args {
v := argsPool.Get()
if v == nil {
return &Args{}
}
return v.(*Args)
}
// ReleaseArgs returns the object acquired via AcquireArgs to the pool.
//
// String not access the released Args object, otherwise data races may occur.
// Copy from fasthttp
func ReleaseArgs(a *Args) {
a.Reset()
argsPool.Put(a)
}
// AcquireFormFile returns an empty FormFile object from the pool.
//
// The returned FormFile may be returned to the pool with ReleaseFormFile
// when no longer needed. This allows reducing GC load.
func AcquireFormFile() *FormFile {
v := formFilePool.Get()
if v == nil {
return &FormFile{}
}
return v.(*FormFile)
}
// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool.
//
// String not access the released FormFile object, otherwise data races may occur.
func ReleaseFormFile(ff *FormFile) {
ff.Fieldname = ""
ff.Name = ""
ff.Content = ff.Content[:0]
ff.autoRelease = false
formFilePool.Put(ff)
}
var (
strHTTP = []byte("http")
strHTTPS = []byte("https")
defaultUserAgent = "fiber"
)
type multipartWriter interface {
Boundary() string
SetBoundary(boundary string) error
CreateFormFile(fieldname, filename string) (io.Writer, error)
WriteField(fieldname, value string) error
Close() error
}

1086
client_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -49,6 +49,14 @@ func lnMetadata(network string, ln net.Listener) (addr string, cfg *tls.Config)
panic("listener: " + addr + ": Only one usage of each socket address (protocol/network address/port) is normally permitted.")
}
cfg = getTlsConfig(ln)
return
}
/* #nosec */
// getTlsConfig returns a net listener's tls config
func getTlsConfig(ln net.Listener) *tls.Config {
// Get listener type
pointer := reflect.ValueOf(ln)
@ -63,13 +71,14 @@ func lnMetadata(network string, ln net.Listener) (addr string, cfg *tls.Config)
// Get element from pointer
if elem := newval.Elem(); elem.Type() != nil {
// Cast value to *tls.Config
cfg = elem.Interface().(*tls.Config)
return elem.Interface().(*tls.Config)
}
}
}
}
}
return
return nil
}
// readContent opens a named file and read content from it

View File

@ -1,7 +1,7 @@
# CSRF Middleware
CSRF middleware for [Fiber](https://github.com/gofiber/fiber) that provides [Cross-site request forgery](https://en.wikipedia.org/wiki/Cross-site_request_forgery) protection by passing a csrf token via cookies. This cookie value will be used to compare against the client csrf token in POST requests. When the csrf token is invalid, this middleware will delete the `_csrf` cookie and return the `fiber.ErrForbidden` error.
CSRF Tokens are generated on GET requests.
CSRF middleware for [Fiber](https://github.com/gofiber/fiber) that provides [Cross-site request forgery](https://en.wikipedia.org/wiki/Cross-site_request_forgery) protection by passing a csrf token via cookies. This cookie value will be used to compare against the client csrf token in POST requests. When the csrf token is invalid, this middleware will delete the `csrf_` cookie and return the `fiber.ErrForbidden` error.
CSRF Tokens are generated on GET requests. You can retrieve the CSRF token with `c.Locals(contextKey)`, where `contextKey` is the string you set in the config (see Custom Config below).
_NOTE: This middleware uses our [Storage](https://github.com/gofiber/storage) package to support various databases through a single interface. The default configuration for this middleware saves data to memory, see the examples below for other databases._

View File

@ -2,6 +2,8 @@ package csrf
import (
"fmt"
"net/textproto"
"strings"
"time"
"github.com/gofiber/fiber/v2"
@ -85,6 +87,9 @@ type Config struct {
//
// Optional. Default: DefaultErrorHandler
ErrorHandler fiber.ErrorHandler
// extractor returns the csrf token from the request based on KeyLookup
extractor func(c *fiber.Ctx) (string, error)
}
// ConfigDefault is the default config
@ -95,6 +100,7 @@ var ConfigDefault = Config{
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUID,
ErrorHandler: defaultErrorHandler,
extractor: csrfFromHeader("X-Csrf-Token"),
}
// default ErrorHandler that process return error from fiber.Handler
@ -157,5 +163,26 @@ func configDefault(config ...Config) Config {
cfg.ErrorHandler = ConfigDefault.ErrorHandler
}
// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(cfg.KeyLookup, ":")
if len(selectors) != 2 {
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
}
// By default we extract from a header
cfg.extractor = csrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))
switch selectors[0] {
case "form":
cfg.extractor = csrfFromForm(selectors[1])
case "query":
cfg.extractor = csrfFromQuery(selectors[1])
case "param":
cfg.extractor = csrfFromParam(selectors[1])
case "cookie":
cfg.extractor = csrfFromCookie(selectors[1])
}
return cfg
}

View File

@ -1,9 +1,6 @@
package csrf
import (
"errors"
"net/textproto"
"strings"
"time"
"github.com/gofiber/fiber/v2"
@ -17,27 +14,6 @@ func New(config ...Config) fiber.Handler {
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(cfg.KeyLookup, ":")
if len(selectors) != 2 {
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
}
// By default we extract from a header
extractor := csrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))
switch selectors[0] {
case "form":
extractor = csrfFromForm(selectors[1])
case "query":
extractor = csrfFromQuery(selectors[1])
case "param":
extractor = csrfFromParam(selectors[1])
case "cookie":
extractor = csrfFromCookie(selectors[1])
}
dummyValue := []byte{'+'}
// Return new handler
@ -51,43 +27,20 @@ func New(config ...Config) fiber.Handler {
// Action depends on the HTTP method
switch c.Method() {
case fiber.MethodGet:
case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace:
// Declare empty token and try to get existing CSRF from cookie
token = c.Cookies(cfg.CookieName)
default:
// Assume that anything not defined as 'safe' by RFC7231 needs protection
// Generate CSRF token if not exist
if token == "" {
// Generate new CSRF token
token = cfg.KeyGenerator()
// Add token to Storage
manager.setRaw(token, dummyValue, cfg.Expiration)
}
// Create cookie to pass token to client
cookie := &fiber.Cookie{
Name: cfg.CookieName,
Value: token,
Domain: cfg.CookieDomain,
Path: cfg.CookiePath,
Expires: time.Now().Add(cfg.Expiration),
Secure: cfg.CookieSecure,
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
}
// Set cookie to response
c.Cookie(cookie)
case fiber.MethodPost, fiber.MethodDelete, fiber.MethodPatch, fiber.MethodPut:
// Extract token from client request i.e. header, query, param, form or cookie
token, err = extractor(c)
token, err = cfg.extractor(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}
// if token does not exist in Storage
if manager.getRaw(token) == nil {
// Expire cookie
c.Cookie(&fiber.Cookie{
Name: cfg.CookieName,
@ -98,14 +51,33 @@ func New(config ...Config) fiber.Handler {
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
})
return cfg.ErrorHandler(c, err)
}
// The token is validated, time to delete it
manager.delete(token)
}
// Generate CSRF token if not exist
if token == "" {
// And generate a new token
token = cfg.KeyGenerator()
}
// Add/update token to Storage
manager.setRaw(token, dummyValue, cfg.Expiration)
// Create cookie to pass token to client
cookie := &fiber.Cookie{
Name: cfg.CookieName,
Value: token,
Domain: cfg.CookieDomain,
Path: cfg.CookiePath,
Expires: time.Now().Add(cfg.Expiration),
Secure: cfg.CookieSecure,
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
}
// Set cookie to response
c.Cookie(cookie)
// Protect clients from caching the response by telling the browser
// a new header value is generated
c.Vary(fiber.HeaderCookie)
@ -119,66 +91,3 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}
}
var (
errMissingHeader = errors.New("missing csrf token in header")
errMissingQuery = errors.New("missing csrf token in query")
errMissingParam = errors.New("missing csrf token in param")
errMissingForm = errors.New("missing csrf token in form")
errMissingCookie = errors.New("missing csrf token in cookie")
)
// csrfFromHeader returns a function that extracts token from the request header.
func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Get(param)
if token == "" {
return "", errMissingHeader
}
return token, nil
}
}
// csrfFromQuery returns a function that extracts token from the query string.
func csrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Query(param)
if token == "" {
return "", errMissingQuery
}
return token, nil
}
}
// csrfFromParam returns a function that extracts token from the url param string.
func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Params(param)
if token == "" {
return "", errMissingParam
}
return token, nil
}
}
// csrfFromForm returns a function that extracts a token from a multipart-form.
func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", errMissingForm
}
return token, nil
}
}
// csrfFromCookie returns a function that extracts token from the cookie header.
func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Cookies(param)
if token == "" {
return "", errMissingCookie
}
return token, nil
}
}

View File

@ -22,41 +22,45 @@ func Test_CSRF(t *testing.T) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Generate CSRF token
ctx.Request.Header.SetMethod("GET")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
methods := [4]string{"GET", "HEAD", "OPTIONS", "TRACE"}
// Without CSRF cookie
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
for _, method := range methods {
// Generate CSRF token
ctx.Request.Header.SetMethod(method)
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
// Empty/invalid CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set("X-CSRF-Token", "johndoe")
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
// Without CSRF cookie
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
// Valid CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("GET")
h(ctx)
token = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
// Empty/invalid CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set("X-CSRF-Token", "johndoe")
h(ctx)
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set("X-CSRF-Token", token)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
// Valid CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(method)
h(ctx)
token = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.Set("X-CSRF-Token", token)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
}
}
// go test -run Test_CSRF_Next

View File

@ -0,0 +1,69 @@
package csrf
import (
"errors"
"github.com/gofiber/fiber/v2"
)
var (
errMissingHeader = errors.New("missing csrf token in header")
errMissingQuery = errors.New("missing csrf token in query")
errMissingParam = errors.New("missing csrf token in param")
errMissingForm = errors.New("missing csrf token in form")
errMissingCookie = errors.New("missing csrf token in cookie")
)
// csrfFromParam returns a function that extracts token from the url param string.
func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Params(param)
if token == "" {
return "", errMissingParam
}
return token, nil
}
}
// csrfFromForm returns a function that extracts a token from a multipart-form.
func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", errMissingForm
}
return token, nil
}
}
// csrfFromCookie returns a function that extracts token from the cookie header.
func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Cookies(param)
if token == "" {
return "", errMissingCookie
}
return token, nil
}
}
// csrfFromHeader returns a function that extracts token from the request header.
func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Get(param)
if token == "" {
return "", errMissingHeader
}
return token, nil
}
}
// csrfFromQuery returns a function that extracts token from the query string.
func csrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Query(param)
if token == "" {
return "", errMissingQuery
}
return token, nil
}
}

View File

@ -2,6 +2,7 @@ package favicon
import (
"io/ioutil"
"net/http"
"strconv"
"github.com/gofiber/fiber/v2"
@ -17,12 +18,18 @@ type Config struct {
// File holds the path to an actual favicon that will be cached
//
// Optional. Default: ""
File string
File string `json:"file"`
// FileSystem is an optional alternate filesystem to search for the favicon in.
// An example of this could be an embedded or network filesystem
//
// Optional. Default: nil
FileSystem http.FileSystem `json:"-"`
// CacheControl defines how the Cache-Control header in the response should be set
//
// Optional. Default: "public, max-age=31536000"
CacheControl string
CacheControl string `json:"cache_control"`
}
// ConfigDefault is the default config
@ -66,9 +73,21 @@ func New(config ...Config) fiber.Handler {
iconLen string
)
if cfg.File != "" {
if icon, err = ioutil.ReadFile(cfg.File); err != nil {
panic(err)
// read from configured filesystem if present
if cfg.FileSystem != nil {
f, err := cfg.FileSystem.Open(cfg.File)
if err != nil {
panic(err)
}
if icon, err = ioutil.ReadAll(f); err != nil {
panic(err)
}
} else {
if icon, err = ioutil.ReadFile(cfg.File); err != nil {
panic(err)
}
}
iconLen = strconv.Itoa(len(icon))
}

View File

@ -1,12 +1,16 @@
package favicon
import (
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// go test -run Test_Middleware_Favicon
@ -71,6 +75,41 @@ func Test_Middleware_Favicon_Found(t *testing.T) {
utils.AssertEqual(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// mockFS wraps local filesystem for the purposes of
// Test_Middleware_Favicon_FileSystem located below
// TODO use os.Dir if fiber upgrades to 1.16
type mockFS struct {
}
func (m mockFS) Open(name string) (http.File, error) {
if name == "/" {
name = "."
} else {
name = strings.TrimPrefix(name, "/")
}
file, err := os.Open(name)
if err != nil {
return nil, err
}
return file, nil
}
// go test -run Test_Middleware_Favicon_FileSystem
func Test_Middleware_Favicon_FileSystem(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
File: "../../.github/testdata/favicon.ico",
FileSystem: mockFS{},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
utils.AssertEqual(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// go test -run Test_Middleware_Favicon_CacheControl
func Test_Middleware_Favicon_CacheControl(t *testing.T) {
app := fiber.New()
@ -79,6 +118,7 @@ func Test_Middleware_Favicon_CacheControl(t *testing.T) {
CacheControl: "public, max-age=100",
File: "../../.github/testdata/favicon.ico",
}))
resp, err := app.Test(httptest.NewRequest("GET", "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")

View File

@ -68,6 +68,7 @@ package main
import (
"embed"
"io/fs"
"log"
"net/http"
@ -75,9 +76,14 @@ import (
"github.com/gofiber/fiber/v2/middleware/filesystem"
)
// Embed a single file
//go:embed index.html
var f embed.FS
// Embed a directory
//go:embed static/*
var embedDirStatic embed.FS
func main() {
app := fiber.New()
@ -85,6 +91,15 @@ func main() {
Root: http.FS(f),
}))
// Access file "image.png" under `static/` directory via URL: `http://<server>/static/image.png`.
// With `http.FS(embedDirStatic)`, you have to access it via URL:
// `http://<server>/static/static/image.png`.
subFS, _ := fs.Sub(embedDirStatic, "static")
app.Use("/static", filesystem.New(filesystem.Config{
Root: http.FS(subFS),
Browse: true,
}))
log.Fatal(app.Listen(":3000"))
}
```

View File

@ -274,6 +274,30 @@ func Test_Session_Cookie(t *testing.T) {
utils.AssertEqual(t, 84, len(ctx.Response().Header.PeekCookie(store.CookieName)))
}
// go test -run Test_Session_Cookie_In_Response
func Test_Session_Cookie_In_Response(t *testing.T) {
t.Parallel()
store := New()
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, _ := store.Get(ctx)
sess.Set("id", "1")
utils.AssertEqual(t, true, sess.Fresh())
sess.Save()
sess, _ = store.Get(ctx)
sess.Set("name", "john")
utils.AssertEqual(t, true, sess.Fresh())
utils.AssertEqual(t, "1", sess.Get("id"))
utils.AssertEqual(t, "john", sess.Get("name"))
}
// go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4
func Benchmark_Session(b *testing.B) {
app, store := fiber.New(), New()

View File

@ -6,6 +6,8 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/gotiny"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
type Store struct {
@ -36,14 +38,23 @@ func (s *Store) RegisterType(i interface{}) {
// Get will get/create a session
func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
var fresh bool
var loadDada = true
// Get key from cookie
id := c.Cookies(s.CookieName)
if len(id) == 0 {
fresh = true
var err error
if id, err = s.responseCookies(c); err != nil {
return nil, err
}
}
// If no key exist, create new one
if len(id) == 0 {
loadDada = false
id = s.KeyGenerator()
fresh = true
}
// Create session object
@ -54,14 +65,13 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
sess.fresh = fresh
// Fetch existing data
if !fresh {
if loadDada {
raw, err := s.Storage.Get(id)
// Unmashal if we found data
if raw != nil && err == nil {
mux.Lock()
gotiny.Unmarshal(raw, &sess.data)
mux.Unlock()
sess.fresh = false
} else if err != nil {
return nil, err
} else {
@ -72,6 +82,26 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
return sess, nil
}
func (s *Store) responseCookies(c *fiber.Ctx) (string, error) {
// Get key from response cookie
cookieValue := c.Response().Header.PeekCookie(s.CookieName)
if len(cookieValue) == 0 {
return "", nil
}
cookie := fasthttp.AcquireCookie()
err := cookie.ParseBytes(cookieValue)
if err != nil {
return "", err
}
value := make([]byte, len(cookie.Value()))
copy(value, cookie.Value())
id := utils.UnsafeString(value)
fasthttp.ReleaseCookie(cookie)
return id, nil
}
// Reset will delete all session from the storage
func (s *Store) Reset() error {
return s.Storage.Reset()

9
utils/json.go Normal file
View File

@ -0,0 +1,9 @@
package utils
// JSONMarshal returns the JSON encoding of v.
type JSONMarshal func(v interface{}) ([]byte, error)
// JSONUnmarshal parses the JSON-encoded data and stores the result
// in the value pointed to by v. If v is nil or not a pointer,
// Unmarshal returns an InvalidUnmarshalError.
type JSONUnmarshal func(data []byte, v interface{}) error

View File

@ -1,5 +0,0 @@
package utils
// JSONMarshal is the standard definition of representing a Go structure in
// json format
type JSONMarshal func(interface{}) ([]byte, error)

View File

@ -1,26 +0,0 @@
package utils
import (
"encoding/json"
"testing"
)
func TestDefaultJSONEncoder(t *testing.T) {
type SampleStructure struct {
ImportantString string `json:"important_string"`
}
var (
sampleStructure = &SampleStructure{
ImportantString: "Hello World",
}
importantString = `{"important_string":"Hello World"}`
jsonEncoder JSONMarshal = json.Marshal
)
raw, err := jsonEncoder(sampleStructure)
AssertEqual(t, err, nil)
AssertEqual(t, string(raw), importantString)
}

41
utils/json_test.go Normal file
View File

@ -0,0 +1,41 @@
package utils
import (
"encoding/json"
"testing"
)
type sampleStructure struct {
ImportantString string `json:"important_string"`
}
func Test_DefaultJSONEncoder(t *testing.T) {
t.Parallel()
var (
ss = &sampleStructure{
ImportantString: "Hello World",
}
importantString = `{"important_string":"Hello World"}`
jsonEncoder JSONMarshal = json.Marshal
)
raw, err := jsonEncoder(ss)
AssertEqual(t, err, nil)
AssertEqual(t, string(raw), importantString)
}
func Test_DefaultJSONDecoder(t *testing.T) {
t.Parallel()
var (
ss sampleStructure
importantString = []byte(`{"important_string":"Hello World"}`)
jsonDecoder JSONUnmarshal = json.Unmarshal
)
err := jsonDecoder(importantString, &ss)
AssertEqual(t, err, nil)
AssertEqual(t, "Hello World", ss.ImportantString)
}